package df

import org.apache.spark.SparkConf
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{DataFrame, Dataset, Encoder, Encoders, Row, SparkSession, functions}

object E4_sql_udf {
  // 创建 Spark 运行配置对象
  val sparkConf = new SparkConf().setMaster("local[*]").setAppName("Hello")
  //创建 SparkSession 对象
  val spark: SparkSession = SparkSession.builder().config(sparkConf).getOrCreate()

  // 导入隐式转换

  import spark.implicits._

  def main(args: Array[String]): Unit = {
//    test_udf
//    test_udaf
    test_udafV2
  }

  def test_udf = {
    // 读取原始数据，创建DataFrame
    val data = Seq(
      (1, "hangge", 30),
      (2, "guge", 25),
      (3, "baige", 15)
    )
    val df: DataFrame = data.toDF("id", "name", "age")

    // 注册UDF
    spark.udf.register("firstLetterToUpper",
      (input: String) => input.substring(0, 1).toUpperCase() + input.substring(1))

    // 创建临时表
    df.createOrReplaceTempView("employees")

    // 使用SQL语法处理DataFrame并显示结果
    val resultDF = spark.sql("SELECT id, firstLetterToUpper(name) AS name_upper FROM employees")
    resultDF.show()

    //关闭 Spark
    spark.stop()
  }
  """
    |（1）UDAF（User-Defined Aggregate Functio）是用户自定义聚合函数
    |定义自己的聚合函数，从而实现复杂的聚合逻辑，如计算平均值、拼接字符串、自定义统计等。
    |
    |（2）UDAF 又可以分为自定义弱类型聚合函数（UserDefinedAggregateFunction）和强类型聚合函数（Aggregator）：
    |自定义弱类型聚合函数（UserDefinedAggregateFunction）是一个抽象类，我们需要继承它并实现一些方法来定义自己的聚合函数。这些方法包括 inputSchema、bufferSchema、dataType、initialize、update、merge 和 evaluate 等。但是，这种方法对数据类型的约束较少，因为数据在处理过程中通常以 Any 类型传递。
    |强类型聚合函数（Aggregator）是一个泛型类，它允许我们定义输入和缓冲区的数据类型，并在编译时进行类型检查。我们需要提供两个函数：zero 函数用于初始化聚合缓冲区，reduce 函数用于更新聚合缓冲区。
    |
    |""".stripMargin

  """
    |（1）下面我们使用强类型聚合函数计算一组学生的平均分和及格率，首先我们定义如下内容：
    |定义一个 Student 类来表示学生的数据
    |定义一个 StudentStats 类来表示学生的统计信息，其中包括平均分和及格率。
    |定义一个强类型聚合函数 StudentStatsAggregator，它使用一个三元组（Double，Long，Long）作为缓冲区，分别存储总分、及格人数和总人数。在 reduce 方法中，我们根据每个学生的成绩更新缓冲区的值。merge 方法用于合并两个缓冲区的值。在 finish 方法中，我们根据缓冲区的值计算平均分和及格率，并返回 StudentStats 对象。
    |
    |""".stripMargin

  def test_udaf = {
    // 读取原始数据，创建Dataset
    val ds = Seq(
      Student(1, "小刘", 85.0),
      Student(2, "小李", 62.5),
      Student(3, "老余", 90.0),
      Student(4, "老杨", 45.0)
    ).toDS()

    // 创建聚合函数
    val studentStatsAggregator = new StudentStatsAggregator

    //将聚合函数转换为查询的列进行查询
    val stats = ds.select(studentStatsAggregator.toColumn).as[StudentStats].first()
    println("平均分: " + stats.averageScore)
    println("及格率: " + stats.passRate)

    //关闭 Spark
    spark.stop()
  }

  def test_udafV2 = {
    // 读取原始数据，创建Dataset
    val ds = Seq(
      Student(1, "小刘", 85.0),
      Student(2, "小李", 62.5),
      Student(3, "老余", 90.0),
      Student(4, "老杨", 45.0)
    ).toDS()

    // 创建聚合函数
    val studentStatsAggregator = new StudentStatsAggregator
    // 注册UDF
    spark.udf.register("studentStatsAggregator", functions.udaf(studentStatsAggregator))

    // 创建临时表
    ds.createOrReplaceTempView("students")

    // 使用SQL语法进行查询
    val resultDS: Dataset[Row] = spark.sql(
      "SELECT studentStatsAggregator(*) as stats, " +
        "stats.averageScore as averageScore, " +
        "stats.passRate as passRate " +
        "FROM students"
    )
    resultDS.show()

    // 将 Dataset[Row] 转换成 Dataset[StudentStats]
    val statsDS: Dataset[StudentStats] = resultDS.as[StudentStats]
    statsDS.show()

    // 获取 StudentStats
    val stats = statsDS.first()
    println("平均分: " + stats.averageScore)
    println("及格率: " + stats.passRate)

    //关闭 Spark
    spark.stop()
  }

  class StudentStatsAggregator extends Aggregator[Student, (Double, Long, Long), StudentStats] {
    // 初始缓冲区的值，由三元组 (总分，及格人数，总人数) 组成
    override def zero: (Double, Long, Long) = (0.0, 0L, 0L)

    // 更新缓冲区的值，根据每个学生的成绩进行更新
    override def reduce(buff: (Double, Long, Long), student: Student): (Double, Long, Long) = {
      val (totalScore, passCount, totalCount) = buff
      val newTotalScore = totalScore + student.score
      val newPassCount = if (student.score >= 60) passCount + 1 else passCount
      val newTotalCount = totalCount + 1
      (newTotalScore, newPassCount, newTotalCount)
    }

    override def merge(b1: (Double, Long, Long), b2: (Double, Long, Long)): (Double, Long, Long) = {
      (b1._1 + b2._1, b1._2 + b2._2, b1._3 + b2._3)
    }

    override def finish(reduction: (Double, Long, Long)): StudentStats = {
      val (totalScore, passCnt, totalCnt) = reduction
      val avgScore = totalScore / totalCnt
      val passRate = passCnt / totalCnt
      StudentStats(avgScore, passRate)
    }

    override def bufferEncoder: Encoder[(Double, Long, Long)] = Encoders.product

    override def outputEncoder: Encoder[StudentStats] = Encoders.product
  }

  // 表示学生的数据结构
  case class Student(id: Int, name: String, score: Double)

  // 表示学生统计信息的数据结构，包括平均分和及格率
  case class StudentStats(averageScore: Double, passRate: Double)
}
